import numpy as np
cimport numpy as np
from numpy cimport ndarray
cimport cython

ctypedef np.float32_t DTYPE_t

# Module-level table, presumably meant to cache precomputed oracle data
# (judging by its name only).
# NOTE(review): not referenced anywhere in the visible source; starts empty.
ORACLE_PRECOMPUTED_TABLE = dict()

# Joint CKY-style decoder: picks the best labeled binary span tree, a head
# token for every span (which yields one father per token), and the best SRL
# label for every (srl span, verb) pair that lies on the chosen tree.
@cython.boundscheck(False)
def decode(int sentence_len, np.ndarray[DTYPE_t, ndim=3] label_scores_chart,
            np.ndarray[DTYPE_t, ndim=3] srl_label_chart,
           np.ndarray[int, ndim=2] srlspan_mat,
           np.ndarray[DTYPE_t, ndim=2] arc_scores_chart,label_vocab, srlspan_vocab):
    """Decode the best joint span tree, dependency heads and SRL labels.

    Spans use fencepost indices ``(left, right)`` with
    ``0 <= left < right <= sentence_len``; token ``t`` (1-based) covers the
    span ``(t - 1, t)``. Heads are 1-based token indices; father index 0 in
    ``arc_scores_chart`` stands for the root.

    Args:
        sentence_len: number of tokens ``n``. Must be >= 1 (see NOTE below).
        label_scores_chart: span label scores indexed
            ``[left, right, label]``. Label 0 is the empty label. Label 1 is
            never chosen for single-token spans or the whole-sentence span.
        srl_label_chart: SRL label scores indexed
            ``[srlspan_id, verb_index, srl_label]``.
        srlspan_mat: maps ``[left, right]`` to a row of ``srl_label_chart``,
            or -1 when the span has no SRL entry. The buffer type is C
            ``int`` (presumably np.int32 -- confirm against the caller).
            Bounds checking is disabled, so ids must be valid.
        arc_scores_chart: dependency scores indexed ``[child, father]``,
            where father 0 means the root.
        label_vocab, srlspan_vocab: not used by this function.

    Returns:
        ``(score, included_i, included_j, included_label, included_srlspan,
        included_father)``:

        - ``score`` is ``value_one_chart[0, n, root_head]``. It does NOT
          include the root arc score (see NOTE near the end).
        - ``included_i`` / ``included_j`` / ``included_label`` list the
          ``2n - 1`` tree nodes in pre-order, left child before right child.
        - ``included_srlspan`` has shape ``srl_label_chart.shape[:2]``. It
          holds the chosen SRL label per (srlspan_id, verb), and 0 where
          nothing was emitted.
        - ``included_father[t - 1]`` is the father of token ``t``
          (0 = root / never assigned).
    """


    cdef DTYPE_t NEG_INF = -np.inf


    # value_one_chart[l, r, h]: best score of span (l, r) headed by token h
    #   when the span keeps its best non-empty label (best_label_chart[l, r, 1]).
    # value_muti_chart[l, r, h]: same, but the span may instead take the
    #   empty label 0 when that scores higher (chosen label stored in
    #   best_label_chart[l, r, 0]).
    cdef np.ndarray[DTYPE_t, ndim=3] value_one_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.float32)
    cdef np.ndarray[DTYPE_t, ndim=3] value_muti_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.float32)

    # Back-pointers for the best derivation of (l, r, h):
    #   split_idx_chart -> split point k,
    #   head_chart      -> head of the child NOT headed by h
    #                      (-1 for single-token spans).
    # best_srlspan_chart holds the argmax SRL label per (srlspan_id, verb_index).
    # father_chart is allocated but never used below.
    cdef np.ndarray[int, ndim=3] split_idx_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.int32)
    cdef np.ndarray[int, ndim=3] best_label_chart = np.zeros((sentence_len+1, sentence_len+1, 2), dtype=np.int32)
    cdef np.ndarray[int, ndim=2] best_srlspan_chart = np.zeros((srl_label_chart.shape[0], srl_label_chart.shape[1]), dtype=np.int32)
    cdef np.ndarray[int, ndim=3] head_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.int32)
    cdef np.ndarray[int, ndim=3] father_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.int32)


    # NOTE(review): several locals declared below are never used
    # (child_*, type_id, label_scores_for_span, oracle_*, argmax_type_index,
    # left_score, right_score, type_max_score, best_split, father).
    # Conversely, head_l, head_r, srlspan_index_iter and idxx are not
    # cdef-typed. Cython may therefore handle them as Python objects inside
    # the hot loops -- consider declaring them as cdef int.
    cdef int length
    cdef int left
    cdef int right

    cdef int child_l
    cdef int child_r
    cdef int child_head
    cdef int child_type
    cdef int type_id
    cdef int srlspan_id


    cdef np.ndarray[DTYPE_t, ndim=1] label_scores_for_span

    cdef int oracle_label_index
    cdef int oracle_type_index
    cdef DTYPE_t label_score_one
    cdef DTYPE_t label_score_empty
    cdef DTYPE_t dep_score
    cdef int argmax_label_index
    cdef int argmax_type_index
    cdef int argmax_srlspan_index
    cdef DTYPE_t left_score
    cdef DTYPE_t right_score
    cdef DTYPE_t type_max_score
    cdef DTYPE_t srlspan_score
    cdef DTYPE_t srlspan_score_verb

    cdef int best_split
    cdef int split_idx # Loop variable for splitting
    cdef DTYPE_t split_val # best so far
    cdef DTYPE_t max_split_val


    cdef int label_index_iter, head, father, verd_iter, srlspan_verb_index

    # Bottom-up over span lengths. Together with the split / head_l / head_r
    # loops below, this gives O(sentence_len ** 5) work.
    for length in range(1, sentence_len + 1):
        for left in range(0, sentence_len + 1 - length):
            right = left + length

            # Best non-empty label for this span. The search starts at
            # label 1 (label 0 is the empty label). Single-token spans and
            # the whole-sentence span start at label 2, i.e. label 1 is
            # excluded for them.
            argmax_label_index = 1
            if length == 1 or length == sentence_len:
                argmax_label_index = 2 #sub_head label can not be leaf

            label_score_one = label_scores_chart[left, right, argmax_label_index]
            for label_index_iter in range(argmax_label_index, label_scores_chart.shape[2]):
                if label_scores_chart[left, right, label_index_iter] > label_score_one:
                    argmax_label_index = label_index_iter
                    label_score_one = label_scores_chart[left, right, label_index_iter]
            best_label_chart[left, right, 1] = argmax_label_index

            label_score_empty = label_scores_chart[left, right,0]

            # SRL contribution of this span. For each verb column take the
            # max-scoring SRL label (search starts at label 0; ties keep the
            # lower index) and sum those maxima over verbs. The result is 0
            # when the span has no SRL entry.
            srlspan_score = 0
            srlspan_id = srlspan_mat[left, right]

            if srlspan_id != -1 :

                for srlspan_verb_index in range(srl_label_chart.shape[1]):

                    srlspan_score_verb = srl_label_chart[srlspan_id, srlspan_verb_index, 0]
                    argmax_srlspan_index = 0

                    for srlspan_index_iter in range(srl_label_chart.shape[2]):
                        if srl_label_chart[srlspan_id, srlspan_verb_index, srlspan_index_iter] > srlspan_score_verb:
                            argmax_srlspan_index = srlspan_index_iter
                            srlspan_score_verb = srl_label_chart[srlspan_id, srlspan_verb_index, srlspan_index_iter]

                    srlspan_score += srlspan_score_verb
                    best_srlspan_chart[srlspan_id, srlspan_verb_index] = argmax_srlspan_index
            if length == 1:
                #head is right, index from 1
                #leaf may have an srl label
                # Single-token span: the only head is the token itself.
                # "one" = non-empty label + SRL. "muti" = max of that and
                # the empty label. Ties go to the empty label here
                # (strict >).
                value_one_chart[left, right, right] = label_score_one + srlspan_score
                value_muti_chart[left, right, right] = label_score_empty + srlspan_score
                if value_one_chart[left, right, right] > value_muti_chart[left, right, right]:
                    value_muti_chart[left, right, right] = value_one_chart[left, right, right]
                    best_label_chart[left, right,0] = best_label_chart[left, right,1]
                else:
                    best_label_chart[left, right,0] = 0 #empty label

                head_chart[left, right, right] = -1

                continue

            #head also in the empty part
            # Reset the "one" score of every possible head to -inf before
            # maximising over split points and child heads.
            for head_l in range(left + 1, right + 1):
                value_one_chart[left, right, head_l] = NEG_INF

            # Combine children (left, split_idx) and (split_idx, right), with
            # head_l in the left child and head_r in the right child. One
            # child's head becomes the span head. The other child hangs off
            # it via an arc and must be a "one" (non-empty) span, unless it is
            # a single token.
            for split_idx in range(left + 1, right):
                for head_l in range(left + 1, split_idx + 1):
                    for head_r in range(split_idx + 1, right + 1):

                        #head in the right empty part, left father is right
                        #left is one, right is multi
                        dep_score = arc_scores_chart[head_l, head_r]
                        if split_idx - left == 1:#leaf can be empty
                            split_val = value_muti_chart[left, split_idx, head_l] + value_muti_chart[split_idx, right, head_r] + dep_score
                        else :
                            split_val = value_one_chart[left, split_idx, head_l] + value_muti_chart[split_idx, right, head_r] + dep_score
                        if split_val > value_one_chart[left, right, head_r]:
                            value_one_chart[left, right, head_r] = split_val
                            split_idx_chart[left, right, head_r] = split_idx
                            head_chart[left, right, head_r] = head_l

                        #head in the left empty part, right father is left
                        #left is multi, right is one
                        dep_score = arc_scores_chart[head_r, head_l]
                        if right - split_idx == 1:#leaf can be empty
                            split_val = value_muti_chart[split_idx, right, head_r] + value_muti_chart[left, split_idx, head_l] + dep_score
                        else:
                            split_val = value_one_chart[split_idx, right, head_r] + value_muti_chart[left, split_idx, head_l] + dep_score
                        if split_val > value_one_chart[left, right, head_l]:
                            value_one_chart[left, right, head_l] = split_val
                            split_idx_chart[left, right, head_l] = split_idx
                            head_chart[left, right, head_l] = head_r

            # Add this span's own label score on top of the structural score:
            #   "muti" picks the better of (non-empty label + SRL) and the
            #          empty label;
            #   "one"  always takes non-empty label + SRL.
            for head_l in range(left + 1, right + 1):
                if label_score_one + srlspan_score > label_score_empty:
                    value_muti_chart[left, right, head_l] = value_one_chart[left, right, head_l] + label_score_one + srlspan_score
                else :
                    value_muti_chart[left, right, head_l] = value_one_chart[left, right, head_l] + label_score_empty
                    # an empty span does not add srlspan_score (per the original
                    # note, the empty srl label scores 0 -- not checked here)
                value_one_chart[left, right, head_l] = value_one_chart[left, right, head_l] + label_score_one + srlspan_score

            # Record the label used by the "muti" state. Ties go to the
            # non-empty label here, whereas the single-token case above sends
            # ties to the empty label. The score is the same either way; only
            # the recovered label differs.
            if label_score_one + srlspan_score < label_score_empty:
                best_label_chart[left, right, 0] = 0
            else:
                best_label_chart[left, right,0] = best_label_chart[left, right,1]

    # Now we need to recover the tree by traversing the chart starting at the
    # root. This iterative implementation is faster than any of my attempts to
    # use helper functions and recursion

    # All fully binarized trees have the same number of nodes
    # NOTE(review): with sentence_len == 0 this is -1, and np.empty below
    # raises ValueError.
    cdef int num_tree_nodes = 2 * sentence_len - 1
    cdef np.ndarray[int, ndim=1] included_i = np.empty(num_tree_nodes, dtype=np.int32)
    cdef np.ndarray[int, ndim=1] included_j = np.empty(num_tree_nodes, dtype=np.int32)

    cdef np.ndarray[int, ndim=1] included_label = np.empty(num_tree_nodes, dtype=np.int32)
    cdef np.ndarray[int, ndim=2] included_srlspan = np.zeros((srl_label_chart.shape[0], srl_label_chart.shape[1]), dtype=np.int32)

    cdef np.ndarray[int, ndim=1] included_father = np.zeros(sentence_len, dtype=np.int32)# 0 is root

    # Explicit DFS stack, 1-based (stack_idx == 0 means empty). Each entry is
    # (i, j, head, nodetype): nodetype 1 reads the "one" label slot and 0 reads
    # the "muti" slot of best_label_chart.
    cdef int idx = 0
    cdef int stack_idx = 1
    # technically, the maximum stack depth is smaller than this
    cdef np.ndarray[int, ndim=1] stack_i = np.empty(num_tree_nodes + 5, dtype=np.int32)
    cdef np.ndarray[int, ndim=1] stack_j = np.empty(num_tree_nodes + 5, dtype=np.int32)
    cdef np.ndarray[int, ndim=1] stack_head = np.empty(num_tree_nodes + 5, dtype=np.int32)

    cdef np.ndarray[int, ndim=1] stack_type = np.empty(num_tree_nodes + 5, dtype=np.int32)

    cdef int i, j, k, root_head, nodetype, sub_head

    # Choose the root head: best whole-sentence "one" score plus the arc from
    # that token to the root (father index 0).
    max_split_val = NEG_INF
    for idxx in range(sentence_len):
        split_val = value_one_chart[0, sentence_len, idxx + 1] + arc_scores_chart[idxx + 1, 0]
        if split_val > max_split_val:
            max_split_val = split_val
            root_head = idxx + 1


    stack_i[1] = 0
    stack_j[1] = sentence_len
    stack_head[1] = root_head
    stack_type[1] = 1

    while stack_idx > 0:

        i = stack_i[stack_idx]
        j = stack_j[stack_idx]
        head = stack_head[stack_idx]
        nodetype = stack_type[stack_idx]
        stack_idx -= 1

        included_i[idx] = i
        included_j[idx] = j

        # Single-token spans always read slot 0, which holds the leaf's
        # non-empty-vs-empty decision.
        if i + 1 == j:
            nodetype = 0
        included_label[idx] = best_label_chart[i, j, nodetype]

        srlspan_id = srlspan_mat[i, j]
        # NOTE(review): SRL labels are emitted only for labels > 1 (or
        # leaves). The chart, however, added srlspan_score for any non-empty
        # label, including index 1 (see value_one_chart updates above).
        # Confirm whether label 1 is meant to carry no SRL.
        if srlspan_id != -1 and (included_label[idx] > 1 or i + 1 == j) :
            #empty not leaf span has no srl label
            for verd_iter in range(srl_label_chart.shape[1]):
                included_srlspan[srlspan_id, verd_iter] = best_srlspan_chart[srlspan_id, verd_iter]

        idx += 1
        if i + 1 < j:

            # Follow back-pointers. The child containing `head` keeps it as a
            # "muti" (type 0) node. The other child is headed by sub_head as a
            # "one" (type 1) node, and sub_head's father is `head`
            # (token indices are 1-based, hence the -1).
            k = split_idx_chart[i, j, head]
            sub_head = head_chart[i,j, head]
            included_father[sub_head - 1] = head

            # Push the right child first so the left child is popped (and
            # recorded) first, giving a left-to-right pre-order.
            stack_idx += 1
            stack_i[stack_idx] = k
            stack_j[stack_idx] = j
            if head > k:
                stack_head[stack_idx] = head
                stack_type[stack_idx] = 0
            else :
                stack_head[stack_idx] = sub_head
                stack_type[stack_idx] = 1
            stack_idx += 1
            stack_i[stack_idx] = i
            stack_j[stack_idx] = k
            if head > k:
                stack_head[stack_idx] = sub_head
                stack_type[stack_idx] = 1
            else :
                stack_head[stack_idx] = head
                stack_type[stack_idx] = 0

    # NOTE(review): running_total is computed but never used or returned.
    cdef DTYPE_t running_total = 0.0
    for idx in range(num_tree_nodes):
        running_total += label_scores_chart[included_i[idx], included_j[idx], included_label[idx]]

    # NOTE(review): this score excludes arc_scores_chart[root_head, 0], even
    # though that term was used to choose root_head (max_split_val includes
    # it). Confirm this is intended.
    cdef DTYPE_t score = value_one_chart[0, sentence_len, root_head]

    # .astype(int) converts the int32 buffers to NumPy's default integer dtype.
    return score, included_i.astype(int), included_j.astype(int), included_label.astype(int), included_srlspan.astype(int), included_father.astype(int)
